{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# High-level RNN Gluon Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import math\n",
    "import mxnet as mx\n",
    "from mxnet import gluon\n",
    "from common.params_lstm import *\n",
    "from common.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.6.4 |Anaconda, Inc.| (default, Jan 16 2018, 18:10:19) \n",
      "[GCC 7.2.0]\n",
      "MXNet:  1.3.0\n",
      "Numpy:  1.13.3\n",
      "GPU:  ['Tesla V100-SXM2-16GB', 'Tesla V100-SXM2-16GB', 'Tesla V100-SXM2-16GB', 'Tesla V100-SXM2-16GB']\n",
      "CUDA Version 9.1.85\n",
      "CuDNN Version  7.1.3\n"
     ]
    }
   ],
   "source": [
    "# Report the platform and library versions this notebook was run with.\n",
    "# The GPU / CUDA / cuDNN lines rely on helpers brought in by the wildcard imports above.\n",
    "print(\"OS: \", sys.platform)\n",
    "print(\"Python: \", sys.version)\n",
    "print(\"MXNet: \", mx.__version__)\n",
    "print(\"Numpy: \", np.__version__)\n",
    "print(\"GPU: \", get_gpu_name())\n",
    "print(get_cuda_version())\n",
    "print(\"CuDNN Version \", get_cudnn_version())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30000\n",
      "125\n",
      "100\n",
      "150\n"
     ]
    }
   ],
   "source": [
    "# Show the model-size constants used below, one value per line\n",
    "print(MAXFEATURES, EMBEDSIZE, NUMHIDDEN, MAXLEN, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNN(gluon.Block):\n",
    "    \"\"\"Embedding -> single-layer GRU -> Dense(2) classifier.\n",
    "\n",
    "    Args:\n",
    "        maxf: input_dim of the embedding layer.\n",
    "        edim: output_dim of the embedding layer.\n",
    "        nhid: hidden_size of the GRU.\n",
    "        **kwargs: forwarded to gluon.Block.__init__.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 maxf=MAXFEATURES, edim=EMBEDSIZE, nhid=NUMHIDDEN, **kwargs):\n",
    "        super(RNN, self).__init__(**kwargs)\n",
    "        self.nhid = nhid\n",
    "        with self.name_scope():\n",
    "            self.embedding = gluon.nn.Embedding(input_dim=maxf,\n",
    "                                                output_dim=edim)\n",
    "            self.gru = gluon.rnn.GRU(hidden_size=nhid,\n",
    "                                     num_layers=1,\n",
    "                                     layout=\"NTC\",\n",
    "                                     bidirectional=False)\n",
    "            self.l_out = gluon.nn.Dense(units=2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply embedding and GRU, then feed the last time step to the Dense layer.\"\"\"\n",
    "        x = self.embedding(x)\n",
    "        x = self.gru(x)  # no state is passed, so the layer uses its default (all-zero) state\n",
    "        # Layout is \"NTC\": indexing axis 1 with an integer already removes the time axis.\n",
    "        # No squeeze() here -- it would also drop the batch axis for a batch of one.\n",
    "        x = x[:, -1, :]\n",
    "        x = self.l_out(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create optimizer and loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_model(net, ctx, lr=LR, b1=BETA_1, b2=BETA_2, eps=EPS):\n",
    "    \"\"\"Initialise `net` on `ctx` and build its Adam trainer and loss.\n",
    "\n",
    "    Args:\n",
    "        net: gluon block whose parameters are Xavier-initialised and optimised.\n",
    "        ctx: context passed to net.initialize.\n",
    "        lr: Adam learning rate.\n",
    "        b1: Adam beta1.\n",
    "        b2: Adam beta2.\n",
    "        eps: Adam epsilon.\n",
    "\n",
    "    Returns:\n",
    "        (trainer, criterion): a gluon.Trainer and a SoftmaxCrossEntropyLoss.\n",
    "    \"\"\"\n",
    "    net.initialize(mx.init.Xavier(), ctx=ctx)\n",
    "    # Use the function arguments; b1/b2/eps were previously ignored in favour of\n",
    "    # the module-level BETA_1/BETA_2/EPS, so callers could not override them.\n",
    "    optimizer_params = {'learning_rate': lr, 'beta1': b1, 'beta2': b2, 'epsilon': eps}\n",
    "    trainer = gluon.Trainer(net.collect_params(), 'adam', optimizer_params)\n",
    "    criterion = gluon.loss.SoftmaxCrossEntropyLoss()\n",
    "    return trainer, criterion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing train set...\n",
      "Preparing test set...\n",
      "Trimming to 30000 max-features\n",
      "Padding to length 150\n",
      "(25000, 150) (25000, 150) (25000,) (25000,)\n",
      "int64 int64 int64 int64\n",
      "CPU times: user 5.63 s, sys: 248 ms, total: 5.88 s\n",
      "Wall time: 5.87 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load the train/test splits via imdb_for_library, using MAXLEN as seq_len\n",
    "# and MAXFEATURES as max_features\n",
    "x_train, x_test, y_train, y_test = imdb_for_library(seq_len=MAXLEN, max_features=MAXFEATURES)\n",
    "# Cast features and labels to int64.\n",
    "# NOTE(review): this step used to be labelled \"Torch-specific\", so the cast looks carried over\n",
    "# from a PyTorch variant of this notebook -- confirm whether the Gluon pipeline needs it.\n",
    "x_train = x_train.astype(np.int64)\n",
    "x_test = x_test.astype(np.int64)\n",
    "y_train = y_train.astype(np.int64)\n",
    "y_test = y_test.astype(np.int64)\n",
    "# Sanity-check shapes and dtypes of all four arrays\n",
    "print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)\n",
    "print(x_train.dtype, x_test.dtype, y_train.dtype, y_test.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device context for the model and every mini-batch: a single GPU (device index 0)\n",
    "ctx = mx.gpu(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 40 ms, sys: 0 ns, total: 40 ms\n",
      "Wall time: 3.47 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Build the network with its default sizes, then initialise it on ctx and\n",
    "# create the trainer / loss function pair used by the training loop\n",
    "net = RNN()\n",
    "trainer, loss_fn = init_model(net, ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [0], loss: 0.4858\n",
      "Epoch [1], loss: 0.2264\n",
      "Epoch [2], loss: 0.1178\n",
      "CPU times: user 15.2 s, sys: 3.01 s, total: 18.2 s\n",
      "Wall time: 11.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "for epoch in range(EPOCHS):\n",
    "    # Sum of the per-batch mean losses for this epoch, accumulated on ctx\n",
    "    epoch_loss = mx.nd.zeros((1), ctx)\n",
    "    minibatches = yield_mb(x_train, y_train, BATCHSIZE, shuffle=True)\n",
    "    for batch_idx, (data, target) in enumerate(minibatches):\n",
    "        # Copy the mini-batch to the device\n",
    "        data = mx.nd.array(data, ctx=ctx)\n",
    "        target = mx.nd.array(target, ctx=ctx)\n",
    "        # Forward pass, recorded so that gradients can be computed\n",
    "        with mx.autograd.record():\n",
    "            loss = loss_fn(net(data), target)\n",
    "        # Backward pass, loss bookkeeping, then one optimizer step\n",
    "        loss.backward()\n",
    "        epoch_loss += loss.mean()\n",
    "        trainer.step(data.shape[0])\n",
    "    # Average of the batch means over the number of batches seen\n",
    "    print(\"Epoch [{}], loss: {:.4f}\".format(epoch, epoch_loss.asscalar()/(batch_idx+1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 9.34 s, sys: 16.2 s, total: 25.5 s\n",
      "Wall time: 1.78 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Main evaluation loop: 1.52s\n",
    "# Only whole batches are scored, so truncate to a multiple of BATCHSIZE\n",
    "n_samples = (y_test.shape[0]//BATCHSIZE)*BATCHSIZE\n",
    "# np.int64 rather than np.int: the np.int alias was removed in NumPy 1.24\n",
    "y_guess = mx.nd.zeros((n_samples), dtype=np.int64)\n",
    "y_truth = y_test[:n_samples]\n",
    "for c, (data, target) in enumerate(yield_mb(x_test, y_test, BATCHSIZE)):\n",
    "    # Only the inputs go to the device; labels are compared later via y_truth\n",
    "    data = mx.nd.array(data, ctx=ctx)\n",
    "    # Forwards\n",
    "    output = net(data)\n",
    "    # Index of the highest-scoring class for each sample\n",
    "    pred = output.topk(k=1).squeeze()\n",
    "    # Collect results\n",
    "    y_guess[c*BATCHSIZE:(c+1)*BATCHSIZE] = pred\n",
    "# Block until all queued operations have finished so %%time covers the real work\n",
    "mx.nd.waitall()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.857892628205\n"
     ]
    }
   ],
   "source": [
    "# Fraction of predictions equal to the truncated labels (vectorised mean, not builtin sum)\n",
    "print(\"Accuracy: \", (y_guess.asnumpy() == y_truth).mean())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_mxnet_p36)",
   "language": "python",
   "name": "conda_mxnet_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
